import jax
import jax.numpy as jnp
import jax.random as random

def load_data(debug=False, *, n=500, noise=0.1, seed=0):
    """Generate a noisy 2-D half-circle toy dataset, centred at the origin.

    ``n`` parameter values ``t`` are spaced evenly in ``[-1, 1]`` and mapped to
    angles ``0.5*pi*t``; the points ``(-sin(angle), cos(angle))`` trace the
    upper half (``y >= 0``) of the unit circle. Gaussian noise scaled by
    ``noise`` is added, then the column mean is subtracted.

    Args:
        debug: if True, print the data shape and summary statistics.
        n: number of points (keyword-only, default 500).
        noise: multiplier applied to standard-normal noise (default 0.1).
        seed: integer seed passed to ``jax.random.PRNGKey`` (default 0).

    Returns:
        ``(x, mean)``: ``x`` has shape ``(n, 2)`` and is mean-centred; ``mean``
        is the ``(2,)`` column mean that was subtracted (add it back to recover
        the uncentred points).
    """
    t = jnp.linspace(-1, 1, n)
    angle = 0.5 * jnp.pi * t
    x = jnp.stack([-jnp.sin(angle), jnp.cos(angle)], axis=1)
    x = x + noise * random.normal(random.PRNGKey(seed), shape=x.shape)
    mean = x.mean(axis=0)
    x = x - mean
    if debug:
        print("Data shape:      ", x.shape)
        print("Data:            ", x.min(), x.max(), x.mean(), x.std())
    return x, mean


class ARCHDataLoader():
    """Mini-batch iterator over samples ``x`` and a pairwise matrix ``dij``.

    Each batch is ``(x[idxs], dij[idxs][:, idxs])``: the selected samples and
    the square sub-matrix of ``dij`` restricted to those same samples. Only
    full batches are produced (a trailing partial batch is dropped), so one
    pass yields exactly ``len(self)`` batches.

    Args:
        x: array-like indexed by sample along its first axis.
        dij: array indexed by sample along its first two axes; presumably an
            ``(n, n)`` pairwise matrix aligned with ``x`` — TODO confirm.
        args: object exposing ``batch_size`` (int) and ``shuffle`` (bool).
        key: JAX PRNG key; split once per epoch when shuffling.
    """

    def __init__(self, x, dij, args, key):
        assert args.batch_size <= len(x), "batch_size must not exceed len(x)"
        self.x = x
        self.n = len(x)
        self.batch_size = args.batch_size
        self.n_batches = self.n // self.batch_size
        self.shuffle = args.shuffle
        self.key1 = key
        self.dij = dij
        self.idx_array = jnp.arange(self.n)

    def __iter__(self):
        if self.shuffle:
            # Shuffle ONCE per epoch: re-permuting before every batch while the
            # slice offset advances lets samples repeat or be skipped.
            self.key1, subkey = random.split(self.key1)
            self.idx_array = random.permutation(subkey, self.idx_array)
        # Bind this epoch's order locally so another iterator over the same
        # loader cannot change it mid-epoch.
        order = self.idx_array
        # range() over full batches only; matches __len__ (n // batch_size),
        # including the case n % batch_size == 0 and batch_size == n.
        stop = self.n_batches * self.batch_size
        for start in range(0, stop, self.batch_size):
            idxs = order[start:start + self.batch_size]
            yield self.x[idxs], self.dij[idxs][:, idxs]

    def __len__(self):
        return self.n_batches
    
class ARCHGenDataLoader():
    """Mini-batch iterator over samples ``x`` only.

    Each batch is ``x[idxs]`` for a contiguous slice of the (optionally
    shuffled) index order. Only full batches are produced (a trailing partial
    batch is dropped), so one pass yields exactly ``len(self)`` batches.

    Args:
        x: array-like indexed by sample along its first axis.
        args: object exposing ``batch_size`` (int) and ``shuffle`` (bool).
        key: JAX PRNG key; split once per epoch when shuffling.
    """

    def __init__(self, x, args, key):
        assert args.batch_size <= len(x), "batch_size must not exceed len(x)"
        self.x = x
        self.n = len(x)
        self.batch_size = args.batch_size
        self.n_batches = self.n // self.batch_size
        self.shuffle = args.shuffle
        self.key1 = key
        self.idx_array = jnp.arange(self.n)

    def __iter__(self):
        if self.shuffle:
            # Shuffle ONCE per epoch: re-permuting before every batch while the
            # slice offset advances lets samples repeat or be skipped.
            self.key1, subkey = random.split(self.key1)
            self.idx_array = random.permutation(subkey, self.idx_array)
        # Bind this epoch's order locally so another iterator over the same
        # loader cannot change it mid-epoch.
        order = self.idx_array
        # range() over full batches only; matches __len__ (n // batch_size),
        # including the case n % batch_size == 0 and batch_size == n.
        stop = self.n_batches * self.batch_size
        for start in range(0, stop, self.batch_size):
            yield self.x[order[start:start + self.batch_size]]

    def __len__(self):
        return self.n_batches